import atexit
import os
import sys
import tarfile
import tempfile
from contextlib import ExitStack
from os.path import exists
from urllib.request import urlopen

if sys.version_info >= (3, 9):
    from importlib.resources import as_file, files
else:
    from importlib_resources import as_file, files

try:
    from .version import __version__  # NOQA
except ImportError:
    raise ImportError("BUG: version.py doesn't exist. Please file a bug report.")

from .htsengine import HTSEngine
from .openjtalk import OpenJTalk
from .openjtalk import mecab_dict_index as _mecab_dict_index
from .utils import merge_njd_marine_features

# Holds the contexts entered via ``as_file`` below so the resulting paths stay
# valid for the whole interpreter session; the stack is closed by the atexit
# hook registered on the next line. (Per the importlib.resources docs,
# ``as_file`` may yield a temporary copy that only lives while its context is
# open.)
_file_manager = ExitStack()
atexit.register(_file_manager.close)

# Traversable handle to this package's bundled resources.
_pyopenjtalk_ref = files(__name__)
# Directory name the downloaded dictionary archive is expected to unpack to.
_dic_dir_name = "open_jtalk_dic_utf_8-1.11"

# Dictionary directory.
# Taken from the OPEN_JTALK_DICT_DIR environment variable if set; otherwise it
# defaults to a directory inside the package, where the dictionary is
# downloaded automatically on first use. Stored as UTF-8 bytes because it is
# handed as-is to ``OpenJTalk`` / ``_mecab_dict_index``, which this module
# otherwise feeds with ``.encode("utf-8")`` paths.
# NOTE: the fallback argument to ``os.environ.get`` is evaluated (and entered
# into ``_file_manager``) even when the environment variable is set.
OPEN_JTALK_DICT_DIR = os.environ.get(
    "OPEN_JTALK_DICT_DIR",
    str(_file_manager.enter_context(as_file(_pyopenjtalk_ref / _dic_dir_name))),
).encode("utf-8")
# Release location of the prebuilt dictionary archive.
_dict_download_url = "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1"
_DICT_URL = f"{_dict_download_url}/open_jtalk_dic_utf_8-1.11.tar.gz"

# Default mei_normal voice for HMM-based TTS, bundled with the package
# (UTF-8 bytes path, like OPEN_JTALK_DICT_DIR).
DEFAULT_HTS_VOICE = str(
    _file_manager.enter_context(
        as_file(_pyopenjtalk_ref / "htsvoice/mei_normal.htsvoice")
    )
).encode("utf-8")

# Global instance of OpenJTalk (created lazily on first use)
_global_jtalk = None
# Global instance of HTSEngine (created lazily on first use)
# mei_normal.voice is used as default
_global_htsengine = None
# Global instance of the marine accent Predictor (created lazily on first use)
_global_marine = None


def _extract_dic():
    """Download the Open JTalk dictionary archive and unpack it into the package.

    The archive at ``_DICT_URL`` is streamed into an anonymous temporary file
    (with a tqdm progress bar) and then extracted into the pyopenjtalk package
    directory. Afterwards the module-level ``OPEN_JTALK_DICT_DIR`` is pointed
    at ``<package dir>/<_dic_dir_name>`` as UTF-8 bytes.
    """
    from tqdm.auto import tqdm

    global OPEN_JTALK_DICT_DIR
    pyopenjtalk_dir = _file_manager.enter_context(as_file(_pyopenjtalk_ref))
    chunk_size = 64 * 1024
    with tempfile.TemporaryFile() as t:
        print('Downloading: "{}"'.format(_DICT_URL))
        with urlopen(_DICT_URL) as response:
            with tqdm.wrapattr(
                t, "write", total=getattr(response, "length", None)
            ) as tar:
                # Read fixed-size chunks: iterating the response object would
                # split the binary gzip stream on b"\n" into arbitrarily sized
                # pieces.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    tar.write(chunk)
        t.seek(0)
        print("Extracting tar file")
        with tarfile.open(mode="r|gz", fileobj=t) as f:
            if hasattr(tarfile, "data_filter"):
                # Interpreters with extraction filters (3.12+ and security
                # backports): refuse absolute paths, ".." components and links
                # pointing outside the target directory.
                f.extractall(path=pyopenjtalk_dir, filter="data")
            else:
                f.extractall(path=pyopenjtalk_dir)
    OPEN_JTALK_DICT_DIR = str(pyopenjtalk_dir / _dic_dir_name).encode("utf-8")


def _lazy_init():
    """Ensure the Open JTalk dictionary directory is available.

    If the path in ``OPEN_JTALK_DICT_DIR`` does not exist yet, the dictionary
    is downloaded and extracted via ``_extract_dic``; otherwise nothing happens.
    """
    dictionary_present = exists(OPEN_JTALK_DICT_DIR)
    if dictionary_present:
        return
    _extract_dic()


def g2p(*args, **kwargs):
    """Grapheme-to-phoneme (G2P) conversion

    This is just a convenient wrapper around `run_frontend`.

    Args:
        text (str): Unicode Japanese text.
        kana (bool): If True, returns the pronunciation in katakana, otherwise in phone.
          Default is False.
        join (bool): If True, concatenate phones or katakana's into a single string.
          Default is True.

    Returns:
        str or list: G2P result in 1) str if join is True 2) list if join is False.
    """
    global _global_jtalk
    engine = _global_jtalk
    if engine is None:
        # First use: make sure the dictionary is present, then build and
        # cache the shared OpenJTalk instance.
        _lazy_init()
        engine = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
        _global_jtalk = engine
    return engine.g2p(*args, **kwargs)


def estimate_accent(njd_features):
    """Accent estimation using marine

    This function requires marine (https://github.com/6gsn/marine). The
    marine ``Predictor`` is created on first call and cached in the module.

    Args:
        njd_features (list): features generated by OpenJTalk.

    Returns:
        list: features for NJDNode with estimation results by marine.

    Raises:
        ImportError: if marine (or something it imports) cannot be loaded.
    """
    global _global_marine
    if _global_marine is None:
        try:
            from marine.predict import Predictor
        except Exception as e:
            # Ordinary import failures are reported with an install hint and
            # the original cause chained; KeyboardInterrupt/SystemExit are no
            # longer swallowed.
            raise ImportError(
                "Please install marine by `pip install pyopenjtalk[marine]`"
            ) from e
        _global_marine = Predictor()
    from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature

    marine_feature = convert_njd_feature_to_marine_feature(njd_features)
    marine_results = _global_marine.predict(
        [marine_feature], require_open_jtalk_format=True
    )
    njd_features = merge_njd_marine_features(njd_features, marine_results)
    return njd_features


def extract_fullcontext(text, run_marine=False):
    """Convert text into HTS full-context labels.

    The text is run through the OpenJTalk frontend; optionally the accents are
    re-estimated with marine before labels are produced.

    Args:
        text (str): Input text
        run_marine (bool): Whether to estimate accent using marine.
          Default is False. Enabling it requires marine, installable with
          `pip install pyopenjtalk[marine]`

    Returns:
        list: List of full-context labels
    """
    features = run_frontend(text)
    features = estimate_accent(features) if run_marine else features
    return make_label(features)


def synthesize(labels, speed=1.0, half_tone=0.0):
    """Run OpenJTalk's speech synthesis backend

    Args:
        labels (list): Full-context labels. A 2-tuple is also accepted, in
          which case its second element is used as the labels.
        speed (float): speech speed rate. Default is 1.0.
        half_tone (float): additional half-tone. Default is 0.

    Returns:
        np.ndarray: speech waveform (dtype: np.float64)
        int: sampling frequency (default: 48000)
    """
    if isinstance(labels, tuple) and len(labels) == 2:
        labels = labels[1]

    global _global_htsengine
    engine = _global_htsengine
    if engine is None:
        # Lazily load the bundled default voice on first synthesis.
        engine = _global_htsengine = HTSEngine(DEFAULT_HTS_VOICE)
    sampling_rate = engine.get_sampling_frequency()
    engine.set_speed(speed)
    engine.add_half_tone(half_tone)
    waveform = engine.synthesize(labels)
    return waveform, sampling_rate


def tts(text, speed=1.0, half_tone=0.0, run_marine=False):
    """Text-to-speech

    Equivalent to ``synthesize(extract_fullcontext(text, run_marine=...), ...)``.

    Args:
        text (str): Input text
        speed (float): speech speed rate. Default is 1.0.
        half_tone (float): additional half-tone. Default is 0.
        run_marine (bool): Whether to estimate accent using marine.
          Default is False. If you want to activate this option, you need to
          install marine by `pip install pyopenjtalk[marine]`

    Returns:
        np.ndarray: speech waveform (dtype: np.float64)
        int: sampling frequency (default: 48000)
    """
    labels = extract_fullcontext(text, run_marine=run_marine)
    return synthesize(labels, speed, half_tone)


def run_frontend(text):
    """Run OpenJTalk's text processing frontend on ``text``.

    The shared OpenJTalk instance is created on first use (downloading the
    dictionary if it is missing).

    Args:
        text (str): Unicode Japanese text.

    Returns:
        list: features for NJDNode.
    """
    global _global_jtalk
    jtalk = _global_jtalk
    if jtalk is None:
        _lazy_init()
        jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
        _global_jtalk = jtalk
    return jtalk.run_frontend(text)


def make_label(njd_features):
    """Build full-context labels from NJD features.

    The shared OpenJTalk instance is created on first use (downloading the
    dictionary if it is missing).

    Args:
        njd_features (list): features for NJDNode.

    Returns:
        list: full-context labels.
    """
    global _global_jtalk
    jtalk = _global_jtalk
    if jtalk is None:
        _lazy_init()
        jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR)
        _global_jtalk = jtalk
    return jtalk.make_label(njd_features)


def mecab_dict_index(path, out_path, dn_mecab=None):
    """Create user dictionary

    Compiles a user dictionary CSV with mecab's dictionary indexer.

    Args:
        path (str or os.PathLike): path to user csv
        out_path (str or os.PathLike): path to output dictionary
        dn_mecab (optional. str, bytes or os.PathLike): path to mecab
          dictionary. Defaults to ``OPEN_JTALK_DICT_DIR``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if the dictionary could not be created.
    """

    def _as_bytes(p):
        # The binding is fed byte paths; text paths are UTF-8 encoded (as
        # before), bytes pass through unchanged.
        p = os.fspath(p)
        return p if isinstance(p, bytes) else p.encode("utf-8")

    if _global_jtalk is None:
        # Make sure the default dictionary is available.
        _lazy_init()
    if not exists(path):
        raise FileNotFoundError("no such file or directory: %s" % path)
    if dn_mecab is None:
        dn_mecab = OPEN_JTALK_DICT_DIR
    # Previously a user-supplied str ``dn_mecab`` was passed through
    # unencoded, unlike ``path``/``out_path``.
    r = _mecab_dict_index(_as_bytes(dn_mecab), _as_bytes(path), _as_bytes(out_path))

    # NOTE: mecab load returns 1 if success, but mecab_dict_index return the opposite
    # yeah it's confusing...
    if r != 0:
        raise RuntimeError("Failed to create user dictionary")


def update_global_jtalk_with_user_dict(path):
    """Update global openjtalk instance with the user dictionary

    Note that this will change the global state of the openjtalk module: the
    shared OpenJTalk instance is replaced by a new one that loads ``path`` as
    its user dictionary.

    Args:
        path (str, bytes or os.PathLike): path to user dictionary

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    global _global_jtalk
    if _global_jtalk is None:
        _lazy_init()
    if not exists(path):
        raise FileNotFoundError("no such file or directory: %s" % path)
    # Normalize to bytes: text/PathLike paths are UTF-8 encoded as before,
    # bytes are passed through unchanged.
    userdic = os.fspath(path)
    if not isinstance(userdic, bytes):
        userdic = userdic.encode("utf-8")
    _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR, userdic=userdic)
